# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from typing import List, Optional, Tuple

# Third-party imports
import mxnet as mx
from mxnet import gluon
import numpy as np

# First-party imports
from gluonts.core.component import DType, validated
from gluonts.model.common import Tensor
from gluonts.mx.block.decoder import Seq2SeqDecoder
from gluonts.mx.block.enc2dec import Seq2SeqEnc2Dec
from gluonts.mx.block.encoder import Seq2SeqEncoder
from gluonts.mx.block.feature import FeatureEmbedder
from gluonts.mx.block.quantile_output import QuantileOutput
from gluonts.mx.block.scaler import MeanScaler, NOPScaler
from gluonts.mx.distribution import DistributionOutput
from gluonts.support.util import weighted_average


class ForkingSeq2SeqNetworkBase(gluon.HybridBlock):
    """
    Base network for the :class:`ForkingSeq2SeqEstimator`.

    Parameters
    ----------
    encoder: Seq2SeqEncoder
        encoder block.
    enc2dec: Seq2SeqEnc2Dec
        encoder to decoder mapping block.
    decoder: Seq2SeqDecoder
        decoder block.
    quantile_output
        quantile output
    distr_output
        distribution output
    context_length: int,
        length of the encoding sequence.
    cardinality: List[int],
        number of values of each categorical feature.
    embedding_dimension: List[int],
        dimension of the embeddings for categorical features.
    scaling
        Whether to automatically scale the target values. (default: True)
    scaling_decoder_dynamic_feature
        Whether to automatically scale the dynamic features for the decoder. (default: False)
    dtype
        (default: np.float32)
    num_forking: int,
        decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-C(R)NN.
    kwargs: dict
        dictionary of Gluon HybridBlock parameters
    """

    @validated()
    def __init__(
        self,
        encoder: Seq2SeqEncoder,
        enc2dec: Seq2SeqEnc2Dec,
        decoder: Seq2SeqDecoder,
        context_length: int,
        cardinality: List[int],
        embedding_dimension: List[int],
        distr_output: Optional[DistributionOutput] = None,
        quantile_output: Optional[QuantileOutput] = None,
        scaling: bool = True,
        scaling_decoder_dynamic_feature: bool = False,
        dtype: DType = np.float32,
        num_forking: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        assert (distr_output is None) != (quantile_output is None)

        self.encoder = encoder
        self.enc2dec = enc2dec
        self.decoder = decoder
        self.distr_output = distr_output
        self.quantile_output = quantile_output
        self.scaling = scaling
        self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature
        self.dtype = dtype
        self.num_forking = (
            num_forking if num_forking is not None else context_length
        )

        if self.scaling:
            self.scaler = MeanScaler()
        else:
            self.scaler = NOPScaler()

        if self.scaling_decoder_dynamic_feature:
            self.scaler_decoder_dynamic_feature = MeanScaler(axis=1)
        else:
            self.scaler_decoder_dynamic_feature = NOPScaler(axis=1)

        with self.name_scope():
            if self.quantile_output:
                self.quantile_proj = self.quantile_output.get_quantile_proj()
                self.loss = self.quantile_output.get_loss()
            else:
                assert self.distr_output is not None
                self.distr_args_proj = self.distr_output.get_args_proj()
            self.embedder = FeatureEmbedder(
                cardinalities=cardinality,
                embedding_dims=embedding_dimension,
                dtype=self.dtype,
            )

    # this method connects the sub-networks and returns the decoder output
    def get_decoder_network_output(
        self,
        F,
        past_target: Tensor,
        past_feat_dynamic: Tensor,
        future_feat_dynamic: Tensor,
        feat_static_cat: Tensor,
        past_observed_values: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        F: mx.symbol or mx.ndarray
            Gluon function space
        past_target: Tensor
            shape (batch_size, encoder_length, 1)
        past_feat_dynamic
            shape (batch_size, encoder_length, num_past_feat_dynamic)
        future_feat_dynamic
            shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
        feat_static_cat
            shape (batch_size, num_feat_static_cat)
        past_observed_values: Tensor
            shape (batch_size, encoder_length, 1)
        Returns
        -------
        decoder output tensor of size (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0])
        """

        # scale shape: (batch_size, 1, 1)
        scaled_past_target, scale = self.scaler(
            past_target, past_observed_values
        )

        # (batch_size, sum(embedding_dimension) = num_feat_static_cat)
        embedded_cat = self.embedder(feat_static_cat)

        # in addition to embedding features, use the log scale as it can help prediction too
        # (batch_size, num_feat_static = sum(embedding_dimension) + 1)
        feat_static_real = F.concat(embedded_cat, F.log(scale), dim=1)

        # Passing past_observed_values as a feature would allow the network to
        # make that distinction and possibly ignore the masked values.
        past_feat_dynamic_extended = F.concat(
            past_feat_dynamic, past_observed_values, dim=-1
        )

        # arguments: target, static_features, dynamic_features
        # enc_output_static shape: (batch_size, channels_seq[-1] + 1)
        # enc_output_dynamic shape: (batch_size, encoder_length, channels_seq[-1] + 1)
        enc_output_static, enc_output_dynamic = self.encoder(
            scaled_past_target, feat_static_real, past_feat_dynamic_extended
        )

        # TODO: This assumes that future_feat_dynamic has no missing values
        # TODO: Output the scale as well to be used by the decoder
        scaled_future_feat_dynamic, _ = self.scaler_decoder_dynamic_feature(
            future_feat_dynamic, F.ones_like(future_feat_dynamic)
        )

        # arguments: encoder_output_static, encoder_output_dynamic, future_features
        # dec_input_static shape: (batch_size, channels_seq[-1] + 1)
        # dec_input_dynamic shape:(batch_size, num_forking, channels_seq[-1] + 1 + decoder_length * num_feat_dynamic)
        dec_input_static, dec_input_dynamic = self.enc2dec(
            enc_output_static,
            # slice axis 1 from encoder_length = context_length to num_forking
            enc_output_dynamic.slice_axis(
                axis=1, begin=-self.num_forking, end=None
            ),
            scaled_future_feat_dynamic,
        )

        # arguments: dynamic_input, static_input
        # TODO: optimize what we pass to the decoder for the prediction case,
        #  where we we only need to pass the encoder output for the last time step
        dec_output = self.decoder(dec_input_dynamic, dec_input_static)

        # the output shape should be: (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0])
        return dec_output, scale


class ForkingSeq2SeqTrainingNetwork(ForkingSeq2SeqNetworkBase):
    # noinspection PyMethodOverriding
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        future_target: Tensor,
        past_feat_dynamic: Tensor,
        future_feat_dynamic: Tensor,
        feat_static_cat: Tensor,
        past_observed_values: Tensor,
        future_observed_values: Tensor,
    ) -> Tensor:
        """
        Parameters
        ----------
        F: mx.symbol or mx.ndarray
            Gluon function space
        past_target: Tensor
            shape (batch_size, encoder_length, 1)
        future_target: Tensor
            shape (batch_size, num_forking, decoder_length)
        past_feat_dynamic
            shape (batch_size, encoder_length, num_past_feat_dynamic)
        future_feat_dynamic
            shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
        feat_static_cat
            shape (batch_size, num_feat_static_cat)
        past_observed_values: Tensor
            shape (batch_size, encoder_length, 1)
        future_observed_values: Tensor
            shape (batch_size, num_forking, decoder_length)

        Returns
        -------
        loss with shape (batch_size, prediction_length)
        """
        # shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0])
        dec_output, scale = self.get_decoder_network_output(
            F,
            past_target,
            past_feat_dynamic,
            future_feat_dynamic,
            feat_static_cat,
            past_observed_values,
        )

        if self.quantile_output is not None:
            # shape: (batch_size, num_forking, decoder_length, len(quantiles))
            dec_dist_output = self.quantile_proj(dec_output)
            # shape: (batch_size, num_forking, decoder_length = prediction_length)
            loss = self.loss(future_target, dec_dist_output)
        else:
            assert self.distr_output is not None
            distr_args = self.distr_args_proj(dec_output)
            distr = self.distr_output.distribution(
                distr_args, scale=scale.expand_dims(axis=1)
            )
            loss = distr.loss(future_target)

        # mask the loss based on observed indicator
        # shape: (batch_size, decoder_length)
        weighted_loss = weighted_average(
            F=F, x=loss, weights=future_observed_values, axis=1
        )

        return weighted_loss


class ForkingSeq2SeqPredictionNetwork(ForkingSeq2SeqNetworkBase):
    # noinspection PyMethodOverriding
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        past_feat_dynamic: Tensor,
        future_feat_dynamic: Tensor,
        feat_static_cat: Tensor,
        past_observed_values: Tensor,
    ) -> Tensor:
        """
        Parameters
        ----------
        F: mx.symbol or mx.ndarray
            Gluon function space
        past_target: Tensor
             shape (batch_size, encoder_length, 1)
        feat_static_cat
            shape (batch_size, num_feat_static_cat)
        past_feat_dynamic
            shape (batch_size, encoder_length, num_past_feat_dynamic)
        future_feat_dynamic
            shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
        past_observed_values: Tensor
            shape (batch_size, encoder_length, 1)

        Returns
        -------
        prediction tensor with shape (batch_size, prediction_length)
        """

        # shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0])
        dec_output, _ = self.get_decoder_network_output(
            F,
            past_target,
            past_feat_dynamic,
            future_feat_dynamic,
            feat_static_cat,
            past_observed_values,
        )

        # We only care about the output of the decoder for the last time step
        # shape: (batch_size, decoder_length, decoder_mlp_dim_seq[0])
        fcst_output = F.slice_axis(dec_output, axis=1, begin=-1, end=None)
        fcst_output = F.squeeze(fcst_output, axis=1)

        # shape: (batch_size, len(quantiles), decoder_length = prediction_length)
        predictions = self.quantile_proj(fcst_output).swapaxes(2, 1)

        return predictions


class ForkingSeq2SeqDistributionPredictionNetwork(ForkingSeq2SeqNetworkBase):
    # noinspection PyMethodOverriding
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        past_feat_dynamic: Tensor,
        future_feat_dynamic: Tensor,
        feat_static_cat: Tensor,
        past_observed_values: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Parameters
        ----------
        F: mx.symbol or mx.ndarray
            Gluon function space
        past_target: Tensor
             shape (batch_size, encoder_length, 1)
        feat_static_cat
            shape (batch_size, encoder_length, num_feature_static_cat)
        past_feat_dynamic
            shape (batch_size, encoder_length, num_feature_dynamic)
        future_feat_dynamic
            shape (batch_size, num_forking, decoder_length, num_feature_dynamic)
        past_observed_values: Tensor
            shape (batch_size, encoder_length, 1)
        Returns
        -------
        distr_args: the parameters of distribution
        loc: an array of zeros with the same shape of scale
        scale:
        """

        dec_output, scale = self.get_decoder_network_output(
            F,
            past_target,
            past_feat_dynamic,
            future_feat_dynamic,
            feat_static_cat,
            past_observed_values,
        )
        fcst_output = F.slice_axis(dec_output, axis=1, begin=-1, end=None)
        fcst_output = F.squeeze(fcst_output, axis=1)
        distr_args = self.distr_args_proj(fcst_output)

        loc = F.zeros_like(scale)
        return distr_args, loc, scale
